--WAF Action
require 'config'
require 'lib'
local db = require "sqllite3"
--args
local rulematch = ngx.re.find
local unescape = ngx.unescape_uri

--- Look up GeoIP data for an address through the optional `maxminddb` module.
-- The module is loaded lazily and initialised with `config_GeoLite2_path`
-- on first use.
-- @param ip address to look up; defaults to `ngx.var.remote_addr`
-- @return the lookup record, or nil when the module cannot be loaded,
--         when `require` yields a number, or when the lookup finds nothing
local function get_ip_position_data(ip)
    local loaded, geo = pcall(require, 'maxminddb')
    -- A failed load and a numeric module value are both treated as "no GeoIP".
    if not loaded or type(geo) == 'number' then
        return nil
    end

    if not geo.initted() then
        geo.init(config_GeoLite2_path)
    end

    local record = geo.lookup(ip or ngx.var.remote_addr)
    if record then
        return record
    end
    return nil
end

--- Tell whether a dotted-quad IPv4 string lies in a private (RFC 1918) range:
-- 10.0.0.0/8, 172.16.0.0/12 or 192.168.0.0/16.
-- The previous range test combined `not is_max(...) and not is_min(...)`.
-- That can never be true, so every valid IPv4 address was reported as intranet.
-- @param ips address string (may be nil or 'unknown')
-- @return true for private IPv4 addresses, false otherwise
--         (including nil, 'unknown', IPv6 and malformed input)
function is_intranet_address(ips)
    if type(ips) ~= 'string' or ips == 'unknown' then
        return false
    end
    -- IPv6 addresses are not classified here.
    if string.find(ips, ':', 1, true) then
        return false
    end

    local a, b, c, d = string.match(ips, '^%s*(%d+)%.(%d+)%.(%d+)%.(%d+)%s*$')
    if not a then
        return false
    end
    a, b, c, d = tonumber(a), tonumber(b), tonumber(c), tonumber(d)
    if a > 255 or b > 255 or c > 255 or d > 255 then
        return false
    end

    if a == 10 then
        return true                              -- 10.0.0.0/8
    end
    if a == 172 and b >= 16 and b <= 31 then
        return true                              -- 172.16.0.0/12
    end
    if a == 192 and b == 168 then
        return true                              -- 192.168.0.0/16
    end
    return false
end

--- Split `str` into the non-empty runs of characters not in `reps`.
-- `reps` is spliced into a Lua pattern character class, so it acts as a set
-- of separator characters rather than a literal multi-character delimiter.
-- @param str string to split; nil yields nil
-- @param reps separator character set (pattern-class syntax)
-- @return list of pieces in order of appearance, or nil when str is nil
function split(str, reps)
    if str == nil then
        return nil
    end
    local pieces = {}
    for piece in string.gmatch(str, '[^' .. reps .. ']+') do
        pieces[#pieces + 1] = piece
    end
    return pieces
end

--- Convert an IPv4 string into an array of numeric octets.
-- @param ipstr address string
-- @return {0,0,0,0} for 'unknown'; the unchanged string when it contains ':'
--         (IPv6); nil for nil input; otherwise the pieces from `split` with the
--         first four converted by tonumber (a malformed octet becomes nil)
function arrip(ipstr)
    -- Previously a nil input crashed inside string.find.
    if ipstr == nil then
        return nil
    end
    if ipstr == 'unknown' then
        return {0, 0, 0, 0}
    end
    if string.find(ipstr, ':', 1, true) then
        return ipstr
    end
    -- Local: this used to leak an `iparr` global shared across requests.
    local iparr = split(ipstr, '.')
    for i = 1, 4 do
        iparr[i] = tonumber(iparr[i])
    end
    return iparr
end

--- Compare two octet arrays (as produced by `arrip`) lexicographically.
-- @param ip1 first octet array
-- @param ip2 second octet array
-- @return true when ip1 >= ip2. Returns false when either argument is missing,
--         or when an octet needed before the order is decided is missing.
function is_min(ip1,ip2)
    if not ip1 or not ip2 then
        return false
    end
    -- The old counter `n` was an accidental, unused global; it is gone.
    for i = 1, 4 do
        local a, b = ip1[i], ip2[i]
        if not a or not b then
            return false
        end
        if a > b then
            return true
        end
        if a < b then
            return false
        end
    end
    return true
end

--- Compare two octet arrays (as produced by `arrip`) lexicographically.
-- @param ip1 first octet array
-- @param ip2 second octet array
-- @return true when ip1 <= ip2. Returns false when either argument is missing,
--         or when an octet needed before the order is decided is missing.
function is_max(ip1,ip2)
    if not ip1 or not ip2 then
        return false
    end
    -- The old counter `n` was an accidental, unused global; it is gone.
    for i = 1, 4 do
        local a, b = ip1[i], ip2[i]
        if not a or not b then
            return false
        end
        if a < b then
            return true
        end
        if a > b then
            return false
        end
    end
    return true
end

--- Create the `totla_log` table if it does not exist yet.
-- The table name is the one the INSERT in `table_log_insert` targets.
-- It runs before every log insert.
function waf_init_db()
    db.exec([[ CREATE TABLE IF NOT EXISTS totla_log (
				id INTEGER PRIMARY KEY AUTOINCREMENT,
				time INTEGER,
				time_localtime TEXT,
				server_name TEXT,
				ip TEXT,
				ip_city TEXT,
				ip_country TEXT,
				ip_subdivisions TEXT,
				ip_continent TEXT,
				ip_longitude TEXT,
				ip_latitude TEXT,
				type TEXT,
				uri TEXT,
				user_agent TEXT,
				filter_rule TEXT,
				incoming_value TEXT,
			    value_risk TEXT,
				http_log TEXT,
				http_log_path INTEGER
			); ]])
end

--- Quote a value as an SQLite string literal.
-- Embedded single quotes are doubled; nil becomes the empty literal ''.
local function _sql_quote(value)
    if value == nil then
        return "''"
    end
    local escaped = string.gsub(tostring(value), "'", "''")
    return "'" .. escaped .. "'"
end

--- Return node.names['zh-CN'] of a GeoIP record node, or '' when any level is missing.
local function _geo_zh_name(node)
    if type(node) ~= 'table' or type(node.names) ~= 'table' then
        return ''
    end
    return node.names['zh-CN'] or ''
end

--- Build the location columns for `ip`.
-- Every field defaults to ''. Without GeoIP data, the country becomes the
-- intranet or unknown-location label.
local function _ip_location_fields(ip)
    local fields = {
        city = '', country = '', subdivisions = '', continent = '',
        longitude = '', latitude = '',
    }
    local position = get_ip_position_data(ip)
    if position == nil then
        if is_intranet_address(ip) then
            fields.country = '内网地址'
        else
            fields.country = '未知位置'
        end
        return fields
    end

    fields.city = _geo_zh_name(position['city'])
    fields.country = _geo_zh_name(position['country'])
    if type(position['subdivisions']) == 'table' then
        fields.subdivisions = _geo_zh_name(position['subdivisions'][1])
    end
    fields.continent = _geo_zh_name(position['continent'])
    local location = position['location']
    if type(location) == 'table' then
        fields.longitude = location['longitude'] or ''
        fields.latitude = location['latitude'] or ''
    end
    return fields
end

--- Timestamp stored in the `time` column: 00:00:00 of a calendar day.
local function _log_day_timestamp()
    local now = os.time()
    -- Offset of local time from UTC, in hours.
    local offset_hours = os.difftime(now, os.time(os.date("!*t", now))) / 3600
    now = now - offset_hours * 3600
    -- NOTE(review): os.date below reads the shifted value as local time, so the
    -- year/month/day are the current UTC date, and os.time then treats that
    -- midnight as local time. This reproduces the original computation; confirm
    -- it is the intended day boundary.
    return os.time({
        year = tonumber(os.date("%Y", now)),
        month = tonumber(os.date("%m", now)),
        day = tonumber(os.date("%d", now)),
        hour = 0,
        min = 0,   -- os.time reads `min`/`sec`; the old `minute`/`second` were ignored
        sec = 0,
    })
end

--- Record a blocked/matched request in the `totla_log` table.
-- Every request-derived value is quoted with `_sql_quote`. Previously the
-- values were spliced raw into the SQL, so a `'` in a User-Agent or URI could
-- break or inject SQL.
-- @param filter_rule human-readable rule description (stored in `filter_rule`)
-- @param rule the matched rule text (stored in `incoming_value`)
-- @param name rule category (stored in `value_risk`)
function table_log_insert(filter_rule,rule,name)
    waf_init_db()
    local client_ip = get_client_ip()
    -- Locals (not globals), so a missing GeoIP field can no longer reuse the
    -- value left over from an earlier request.
    local loc = _ip_location_fields(client_ip)
    local values = {
        tostring(_log_day_timestamp()),
        _sql_quote(ngx.localtime()),
        _sql_quote(ngx.var.server_name),
        _sql_quote(client_ip),
        _sql_quote(loc.city),
        _sql_quote(loc.country),
        _sql_quote(loc.subdivisions),
        _sql_quote(loc.continent),
        _sql_quote(loc.longitude),   -- previously the latitude was written here
        _sql_quote(loc.latitude),
        _sql_quote(ngx.req.get_method()),
        _sql_quote(ngx.var.request_uri),
        _sql_quote(get_user_agent()),
        _sql_quote(filter_rule),
        _sql_quote(rule),
        _sql_quote(name),
        _sql_quote(ngx.unescape_uri(ngx.var.uri)),
        _sql_quote(0),
    }
    local sql = "INSERT INTO totla_log(time,time_localtime,server_name,ip, ip_city,ip_country,ip_subdivisions,ip_continent,ip_longitude,ip_latitude,type,uri,user_agent,filter_rule,incoming_value,value_risk,http_log,http_log_path) VALUES("
        .. table.concat(values, ",") .. ");"
    db.exec(sql)
end

--allow white ip
--- Return true when the client IP matches a rule from the 'whiteip' list.
-- Returns nil when the check is disabled or nothing matches.
function white_ip_check()
    if config_white_ip_check == "on" then
        local IP_WHITE_RULE = read_rule('whiteip')
        local WHITE_IP = get_client_ip()
        if IP_WHITE_RULE ~= nil then
            for _,rule in pairs(IP_WHITE_RULE) do
                if rule ~= "" and rulematch(WHITE_IP,rule,"isjo") then
                    -- ngx.var.request_uri: the old `ngx.var_request_uri` is not a field and was always nil.
                    log_record('White_IP',ngx.var.request_uri,"_","_")
                    return true
                end
            end
        end
    end
end

--deny black ip
--- Log and, when the WAF is enabled, reject (403) clients whose IP matches
-- a rule from the 'blackip' list.
-- @return true after ngx.exit(403); nil otherwise
function black_ip_check()
    if config_black_ip_check == "on" then
        local IP_BLACK_RULE = read_rule('blackip')
        local BLACK_IP = get_client_ip()
        if IP_BLACK_RULE ~= nil then
            for _,rule in pairs(IP_BLACK_RULE) do
                if rule ~= "" and rulematch(BLACK_IP,rule,"isjo") then
                    table_log_insert('黑名单ip',rule,'blackip')
                    -- ngx.var.request_uri: the old `ngx.var_request_uri` was always nil.
                    log_record('BlackList_IP',ngx.var.request_uri,"_","_")
                    if config_waf_enable == "on" then
                        ngx.exit(403)
                        return true
                    end
                end
            end
        end
    end
end

--allow white url
--- Return true when the request URI matches a rule from the 'whiteurl' list.
-- Returns nil when the check is disabled, no rules exist, or nothing matches.
function white_url_check()
    if config_white_url_check ~= "on" then
        return
    end
    local rules = read_rule('whiteurl')
    local request_uri = ngx.var.request_uri
    if rules == nil then
        return
    end
    for _, rule in pairs(rules) do
        if rule ~= "" and rulematch(request_uri, rule, "isjo") then
            return true
        end
    end
end

--deny cc attack
--- Per-IP/URI request-rate limiter backed by the `limit` shared dict.
-- `config_cc_rate` is parsed as "<count>/<seconds>". Once a client/URI pair
-- exceeds <count> hits within its counting window, the request is rejected
-- with 403 (when the WAF is enabled). The window grows with the per-IP
-- offence counter and is capped at one day.
-- @return false (blocking happens through ngx.exit)
function cc_attack_check()
    if config_cc_check ~= "on" then
        return false
    end

    -- Locals: these used to be globals (CCcount/CCseconds) written on every request.
    local rate = config_cc_rate
    local cc_count, cc_seconds
    if type(rate) == 'string' then
        cc_count = tonumber(string.match(rate, '(.*)/') or '')
        cc_seconds = tonumber(string.match(rate, '/(.*)') or '')
    end
    if not cc_count or not cc_seconds then
        -- Previously a malformed rate raised a Lua error on every request.
        ngx.log(ngx.ERR, "cc_attack_check: invalid config_cc_rate: ", tostring(rate))
        return false
    end

    local cip = get_client_ip()
    local CC_TOKEN = cip .. ngx.var.uri
    local limit = ngx.shared.limit

    -- Offence counter for this IP; it scales the counting/lock window.
    local safe_count = limit:get(cip)
    if not safe_count then
        limit:set(cip, 1, 86400)
        safe_count = 1
    end
    local lock_time = cc_seconds * (safe_count + safe_count)
    if lock_time > 86400 then lock_time = 86400 end

    local req = limit:get(CC_TOKEN)
    if not req then
        limit:set(CC_TOKEN, 1, lock_time)
    elseif req > cc_count then
        local sf_status = cip .. "status"
        -- Log and bump the offence counter only once per lock period.
        if not limit:get(sf_status) then
            limit:set(sf_status, 1, lock_time)
            limit:incr(cip, 1)
            log_record('CC_Attack', ngx.var.request_uri, lock_time, "-")
        end
        if config_waf_enable == "on" then
            ngx.exit(403)
        end
    else
        limit:incr(CC_TOKEN, 1)
    end
    return false
end

--deny cookie
--- Match the Cookie header against the 'cookie' rule list.
-- Each hit is logged. When the WAF is enabled, the first hit sends the block
-- page via waf_output().
-- @return true when a block page was sent, false otherwise
function cookie_attack_check()
    if config_cookie_check ~= "on" then
        return false
    end
    local rules = read_rule('cookie')
    if rules == nil then
        return false
    end
    local cookie = ngx.var.http_cookie
    if cookie == nil then
        return false
    end
    for index, rule in pairs(rules) do
        if rule ~= "" and rulematch(cookie, rule, "isjo") then
            table_log_insert(select_rule_note(index), rule, 'cookie')
            log_record('Deny_Cookie', ngx.var.request_uri, "-", rule)
            if config_waf_enable == "on" then
                waf_output()
                return true
            end
        end
    end
    return false
end

--deny url
--- Match the raw request URI against the 'url' rule list.
-- Each hit is logged. When the WAF is enabled, the first hit sends the block
-- page via waf_output().
-- @return true when a block page was sent, false otherwise
function url_attack_check()
    if config_url_check ~= "on" then
        return false
    end
    local rules = read_rule('url')
    if rules == nil then
        return false
    end
    local request_uri = ngx.var.request_uri
    for index, rule in pairs(rules) do
        local hit = rule ~= "" and rulematch(request_uri, rule, "isjo")
        if hit then
            table_log_insert(select_rule_note(index), rule, 'url')
            log_record('Deny_URL', request_uri, "-", rule)
            if config_waf_enable == "on" then
                waf_output()
                return true
            end
        end
    end
    return false
end

--deny url args
--- Turn one value from ngx.req.get_uri_args() into the text to inspect.
-- The value is a string, `true` for a bare flag, or a table for a repeated key.
-- Non-string entries are skipped; previously they crashed unescape/table.concat.
-- @return the string (table entries joined by " "), or nil when nothing is inspectable
local function _uri_arg_text(val)
    if type(val) == 'string' then
        return val
    end
    if type(val) == 'table' then
        local parts = {}
        for _, item in ipairs(val) do
            if type(item) == 'string' then
                parts[#parts + 1] = item
            end
        end
        if #parts > 0 then
            return table.concat(parts, " ")
        end
    end
    return nil
end

--- Match each (unescaped) URI argument value against the 'args' rule list.
-- Each hit is logged. When the WAF is enabled, the first hit sends the block
-- page via waf_output().
-- @return true when a block page was sent, false otherwise
function url_args_attack_check()
    if config_url_args_check ~= "on" then
        return false
    end
    local ARGS_RULES = read_rule('args')
    if ARGS_RULES == nil then
        return false
    end
    -- Read once; it previously ran once per rule.
    local req_args = ngx.req.get_uri_args()
    for ke, rule in pairs(ARGS_RULES) do
        if rule ~= "" then
            for _, val in pairs(req_args) do
                -- Local: ARGS_DATA used to be a global. The unconditional
                -- 'Deny_URL_Args' log that used to run here for every argument is removed.
                local args_data = _uri_arg_text(val)
                if args_data and rulematch(unescape(args_data), rule, "isjo") then
                    local rulenote = select_rule_note(ke)
                    table_log_insert(rulenote, rule, 'args')
                    log_record('Deny_URL_Args', ngx.var.request_uri, "-", rule)
                    if config_waf_enable == "on" then
                        waf_output()
                        return true
                    end
                end
            end
        end
    end
    return false
end
--deny user agent
--- Match the User-Agent header against the 'useragent' rule list.
-- Each hit is logged. When the WAF is enabled, the first hit sends the block
-- page via waf_output().
-- @return true when a block page was sent, false otherwise
function user_agent_attack_check()
    if config_user_agent_check ~= "on" then
        return false
    end
    local rules = read_rule('useragent')
    if rules == nil then
        return false
    end
    local agent = ngx.var.http_user_agent
    if agent == nil then
        return false
    end
    for index, rule in pairs(rules) do
        local hit = rule ~= "" and rulematch(agent, rule, "isjo")
        if hit then
            table_log_insert(select_rule_note(index), rule, 'useragent')
            log_record('Deny_USER_AGENT', ngx.var.request_uri, "-", rule)
            if config_waf_enable == "on" then
                waf_output()
                return true
            end
        end
    end
    return false
end

--- Store a scalar under `key` in `acc`.
-- The first value is stored as-is. If the slot already holds a string, it is
-- promoted to a list; further values are appended to an existing list. Any
-- other existing scalar is overwritten.
local function _append_json_value(acc, key, value)
    local current = acc[key]
    if type(current) == "table" then
        current[#current + 1] = value
    elseif type(current) == "string" then
        acc[key] = { current, value }
    else
        acc[key] = value
    end
end

--- Flatten a decoded JSON value into a key -> value(s) table.
-- Scalars inside a table value are collected under the outer key. Tables
-- nested one level deeper are flattened recursively into the same accumulator
-- under their own keys.
-- @param json_args decoded JSON (non-tables yield {})
-- @param t optional accumulator to merge into
-- @return the accumulator
local function _process_json_args(json_args, t)
    if type(json_args) ~= 'table' then
        return {}
    end
    local acc = t or {}
    for key, value in pairs(json_args) do
        if type(value) ~= 'table' then
            _append_json_value(acc, key, value)
        else
            for _, item in pairs(value) do
                if type(item) == "table" then
                    acc = _process_json_args(item, acc)
                else
                    _append_json_value(acc, key, item)
                end
            end
        end
    end
    return acc
end

-- POST parameter names for which continue_key returns false. is_ngx_match_urlencoded
-- skips rule matching for those keys.
local _UNCHECKED_POST_KEYS = {
    content = true, contents = true, body = true, msg = true, file = true,
    files = true, img = true, newcontent = true, message = true,
    subject = true, kw = true, srchtxt = true, [""] = true,
}

--- Decide whether a request parameter name should be rule-checked.
-- Non-POST requests check every key. For POST, keys longer than 64 bytes and
-- the names in _UNCHECKED_POST_KEYS are skipped.
-- @param key parameter name (converted with tostring)
-- @return true to check the key, false to skip it
function continue_key(key)
    if ngx.req.get_method() ~= 'POST' then
        return true
    end
    local name = tostring(key)
    if #name > 64 then
        return false
    end
    return not _UNCHECKED_POST_KEYS[name]
end

--- Collect the string values to inspect for one parameter.
-- A table (repeated POST key, or a JSON list flattened by _process_json_args)
-- contributes each of its string entries.
local function _param_strings(body)
    if type(body) == "string" then
        return { body }
    end
    local out = {}
    if type(body) == "table" then
        for _, item in pairs(body) do
            if type(item) == "string" then
                out[#out + 1] = item
            end
        end
    end
    return out
end

--- Match request parameters (names and values) against the POST rules.
-- @param rules rule string or list of rule strings
-- @param sbody string, or table of name -> value (string, true, or list of strings)
-- @return true when a block page was sent (WAF enabled and a rule matched),
--         false otherwise
function is_ngx_match_urlencoded(rules,sbody)
    if rules == nil or sbody == nil then return false end
    if type(sbody) == "string" then
        sbody = {sbody}
    end
    if type(rules) == "string" then
        rules = {rules}
    end
    local is_post = ngx.req.get_method() == "POST"
    for k, body in pairs(sbody) do
        if continue_key(k) then
            -- Table values used to be skipped entirely, which let repeated
            -- keys / JSON arrays bypass the value rules.
            local values = _param_strings(body)
            for _, rule in ipairs(rules) do
                if body and rule ~= "" then
                    -- For non-POST requests the "'$" rule is ignored on values.
                    -- It used to `return false` here, abandoning every remaining check.
                    local exempt = (not is_post) and rule == "'$"
                    if not exempt then
                        for _, value in ipairs(values) do
                            if rulematch(ngx.unescape_uri(value), rule, "isjo") then
                                if config_waf_enable == "on" then
                                    table_log_insert('POST参数过滤',rule,'post')
                                    waf_output()
                                    return true
                                end
                            end
                        end
                    end
                    if type(k) == "string" and rulematch(ngx.unescape_uri(k), rule, "isjo") then
                        if config_waf_enable == "on" then
                            table_log_insert('POST参数过滤',rule,'post')
                            waf_output()
                            return true
                        end
                    end
                end
            end
        end
    end
    return false
end

--- Escape text for safe inclusion in an HTML document.
local function _post_html_escape(text)
    return (string.gsub(tostring(text or ''), '[&<>"\']', {
        ['&'] = '&amp;',
        ['<'] = '&lt;',
        ['>'] = '&gt;',
        ['"'] = '&quot;',
        ["'"] = '&#39;',
    }))
end

--- Send an HTML notice page and finish the request with ngx.exit(200).
local function _post_check_notice(html)
    ngx.header.content_type = "text/html;charset=utf8"
    ngx.header.Cache_Control = "no-cache"
    ngx.say(html)
    ngx.exit(200)
end

--- Check one POST value for oversized payloads and base64 data URIs that
-- declare a php/jsp type.
-- @return true when waf_output() was sent, false otherwise
local function _post_value_blocked(v)
    if type(v) ~= 'string' then
        return false
    end
    if not string.find(v, '^data:.+/.+;base64,') then
        if #v >= 400000 then
            log_record('Deny_post_Args',ngx.var.request_uri,"-",'参数值长度超过40w已被系统拦截')
            if config_waf_enable == "on" then
                waf_output()
                return true
            end
        end
        return false
    end
    local header = ngx.re.match(v, '^data:.+;base64,', 'ijo')
    -- Case-insensitive: MIME types are case-insensitive, so "x-PHP" must match too.
    if header and header[0]
        and (ngx.re.find(header[0], 'php', 'ijo') or ngx.re.find(header[0], 'jsp', 'ijo')) then
        if config_waf_enable == "on" then
            table_log_insert('拦截Bae64上传php文件',"php|jsp",'post')
            waf_output()
            return true
        end
    end
    return false
end

--- Inspect non-GET, non-multipart request bodies.
-- Checks oversized or base64 php/jsp values, then applies the 'post' rules to
-- form or JSON arguments.
-- @return true when the request was blocked/answered, false otherwise
function post_attack_check()
    if config_post_check ~= "on" then return false end
    if ngx.req.get_method() == "GET" then return false end
    -- Locals: these were previously globals shared across requests.
    local request_header = ngx.req.get_headers(20000)
    local content_length = tonumber(request_header['content-length'])
    -- NOTE(review): bodies without Content-Length (e.g. chunked) skip this check.
    if content_length == nil then return false end
    local content_type = request_header["content-type"]
    if not content_type then return false end
    if type(content_type) ~= 'string' then
        -- Repeated Content-Type headers arrive as a table.
        if config_waf_enable == "on" then
            table_log_insert('http包非法',table.concat(content_type, ","),'post')
            waf_output()
            return true
        end
        -- The string-only checks below would raise on a table.
        return false
    end
    if ngx.re.find(content_type, 'multipart', "oij") then return false end

    ngx.req.read_body()
    local request_args = ngx.req.get_post_args(1000000)
    if not request_args then
        if content_length > 10000 then
            -- request_uri is HTML-escaped to avoid reflecting markup into the page.
            local check_html = [[<html><meta charset="utf-8" /><title>Nginx缓冲区溢出</title><div>WAF提醒您,Nginx缓冲区溢出,传递的参数超过接受参数的大小,出现异常,<br>第一种解决方案:把当前url-->]]..'^'.._post_html_escape(ngx.var.request_uri)..[[加入到URL白名单中,如有疑问请联系官方运维QQ</br>第二种解决方案:面板-->nginx管理->性能调整-->client_body_buffer_size的值调整为10240K 或者5024K(PS:可能会一直请求失败建议加入白名单)</br></div></html>]]
            _post_check_notice(check_html)
        end
        -- NOTE(review): returns true without output when content_length <= 10000; confirm intended.
        return true
    end

    if type(request_args) == 'table' then
        for _, v in pairs(request_args) do
            if type(v) == 'table' then
                -- Repeated keys: inspect every value, not just single ones.
                for _, item in pairs(v) do
                    if _post_value_blocked(item) then return true end
                end
            elseif _post_value_blocked(v) then
                return true
            end
        end
    end

    local POST_RULES = read_rule('post')
    if ngx.re.find(content_type, '^application/json', "oij") and content_length ~= 0 then
        local ok, json_args = pcall(function()
            return json.decode(ngx.req.get_body_data())
        end)
        if not ok then
            _post_check_notice([[<html><meta charset="utf-8" /><title>json格式错误</title><div>请传递正确的json参数</div></html>]])
            return false
        end
        if type(json_args) ~= 'table' then return false end
        return is_ngx_match_urlencoded(POST_RULES, _process_json_args(json_args))
    end
    return is_ngx_match_urlencoded(POST_RULES, request_args)
end
